#!/pxrpythonsubst
#
# Copyright 2016 Pixar
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from future.utils import iteritems
from builtins import zip

from pxr import Gf

from maya import cmds
from maya.api import OpenMaya as OM
from maya import standalone

import os
import unittest
import pprint

import fixturesUtils

class testUsdImportXforms(unittest.TestCase):

    EPSILON = 1e-6

    @classmethod
    def setUpClass(cls):
        cls.inputPath = fixturesUtils.setUpClass(__file__)

    @classmethod
    def tearDownClass(cls):
        standalone.uninitialize()

    def setUp(cls):
        # Create a new file so each test case starts with a fresh state.
        cmds.file(new=1, f=1)        

    def _GetMayaTransform(self, transformName):
        selectionList = OM.MSelectionList()
        selectionList.add(transformName)
        mObj = selectionList.getDependNode(0)

        return OM.MFnTransform(mObj)

    def testImportInverseXformOpsOnly(self):
        """
        Tests that importing a USD cube mesh that has XformOps on it all tagged
        as inverse ops results in the correct transform when imported into Maya.
        """
        usdFile = os.path.join(self.inputPath, "UsdImportXformsTest", "UsdImportXformsTest.usda")
        cmds.usdImport(file=usdFile, shadingMode='none')

        mayaTransform = self._GetMayaTransform('InverseOpsOnlyCube')
        transformationMatrix = mayaTransform.transformation()

        expectedTranslation = [-1.0, -2.0, -3.0]
        actualTranslation = list(
            transformationMatrix.translation(OM.MSpace.kTransform))
        self.assertTrue(
            Gf.IsClose(expectedTranslation, actualTranslation, self.EPSILON))

        expectedRotation = [0.0, 0.0, Gf.DegreesToRadians(-45.0)]
        actualRotation = list(transformationMatrix.rotation())
        self.assertTrue(
            Gf.IsClose(expectedRotation, actualRotation, self.EPSILON))

        expectedScale = [2.0, 2.0, 2.0]
        actualScale = list(transformationMatrix.scale(OM.MSpace.kTransform))
        self.assertTrue(
            Gf.IsClose(expectedScale, actualScale, self.EPSILON))
        
    def testImportMayaXformVariations(self):
        """
        Tests that all combinations of the various maya xform pieces will
        import correctly
        """
        
        # create a bunch of transforms with varying transform attrs set
        from random import Random
        import itertools
        import os
        import pprint
        
        ATTRS = {
            'translate': (.01, 5),
            # we do rotates separately, so we can see "rotateY" and "rotateXYZ" ops
            'rotateX': (.01, 359.99),
            'rotateY': (.01, 359.99),
            'rotateZ': (.01, 359.99),
            'scale': (1.01, 2.0),
            'shear': (1.01, 2.0),
            'rotateOrder': (1, 5), 
            # it seems that internally rotateAxis is stored as a quaternion...
            # so to ensure proper roundtripping, keep values 0 < x < 90
            'rotateAxis': (.01, 89.99),
            'rotatePivot': (.01, 5),
            'scalePivot': (.01, 5),
            'rotatePivotTranslate': (.01, 5),
            'scalePivotTranslate': (.01, 5),
        }
        
        rand = Random(3)

        allNodes = []
        allExpected = {}
        
        topPrim = cmds.createNode('transform', name='topPrim')
        
        # Iterate through all combinations of whether each attr in ATTRS is set or not
        for i, enabledArray in enumerate(itertools.product((False, True), repeat=len(ATTRS))):
            # name will be like: mayaXform_000111010001
            node = 'mayaXform_{}'.format(''.join(str(int(x)) for x in enabledArray))
            node = cmds.createNode('transform', name=node, parent=topPrim)
            attrVals = {}
            allNodes.append(node)
            allExpected[node] = attrVals
            for enabled, (attr, (valMin, valMax)) in zip(enabledArray, iteritems(ATTRS)):
                if not enabled:
                    if attr in ('rotateOrder', 'rotateX', 'rotateY', 'rotateZ'):
                        attrVals[attr] = 0
                    elif attr == 'scale':
                        attrVals[attr] = (1, 1, 1)
                    else:
                        attrVals[attr] = (0, 0, 0)
                else:
                    if attr == 'rotateOrder':
                        # 1 - 5 because 0 (xyz) would correspond to "not enabled"
                        val = rand.randint(1, 5)
                    elif attr in ('rotateX', 'rotateY', 'rotateZ'):
                        val = rand.uniform(valMin, valMax)
                    else:
                        val = (rand.uniform(valMin, valMax),
                            rand.uniform(valMin, valMax),
                            rand.uniform(valMin, valMax))
                    attrVals[attr] = val
                    #node.setAttr(attr, val)
                    if isinstance(val, tuple):
                        cmds.setAttr("{}.{}".format(node, attr), *val)
                    else:
                        cmds.setAttr("{}.{}".format(node, attr), val)
        
        # Now write out a usd file with all our xforms...
        cmds.select(allNodes)
        usdPath = os.path.abspath('UsdImportMayaXformVariationsTest.usdc')
        cmds.usdExport(selection=1, file=usdPath)
        
        # Now import, and make sure it round-trips as expected
        cmds.file(new=1, f=1)
        cmds.usdImport(file=usdPath)
        for node, attrVals in iteritems(allExpected):
            # if only one (or less) of the three rotates is non-zero, then
            # the rotate order doesn't matter...
            nonZeroRotates = [attrVals['rotate' + dir] != 0 for dir in 'XYZ']
            skipRotateOrder = sum(int(x) for x in nonZeroRotates) <= 1 
            
            for attr, expectedVal in iteritems(attrVals):
                if attr == 'rotateOrder' and skipRotateOrder:
                    continue
                attrName = "{}.{}".format(node, attr)
                actualVal = cmds.getAttr(attrName)
                if not isinstance(expectedVal, tuple):
                    expectedVal = (expectedVal,)
                    actualVal = (actualVal,)
                else:
                    # cmds.getAttr('persp.scale') returns [(0, 0, 0)]... weird
                    actualVal = actualVal[0]
                for expected, actual in zip(expectedVal, actualVal):
                    try:
                        self.assertAlmostEqual(expected, actual,
                            msg="{} - expected {}, got {} (diff: {}".format(
                                attrName, expected, actual, abs(expected - actual)),
                            delta=1e-4)
                    except Exception:
                        print("full failed xform:")
                        pprint.pprint(attrVals)
                        raise
                    
    def testImportXformsRotateAxis(self):
        """
        Tests that importing xforms that have a rotateAxis with rotate order other than just XYZ
        still imports correctly
        """
        usdFile = os.path.join(self.inputPath, "UsdImportXformsTest", "UsdImportXformsTestRotateAxis.usda")
        cmds.usdImport(file=usdFile, shadingMode='none')
        
        expectedRotates = {
            'X': (60, 0, 0),
            'Y': (0, 60, 0),
            'Z': (0, 0, 60),
            'XYZ': (-120, 60, 0),
            'YZX': (-106.1021138, 25.6589063, 56.3099325),
            'ZXY': (120, -60, 0),
            'XZY': (-120, -60, 0),
            'YXZ': (106.1021138, 25.6589063, -56.3099325),
            'ZYX': (-106.1021138, -25.6589063, -56.3099325),
        }
        expectedScale = (.5, .5, .5)
        expectedTranslation = (1.0, 2.0, 3.0)
        
        for rotOrderName, expectedRotation in iteritems(expectedRotates):
            mayaTransform = self._GetMayaTransform(rotOrderName)
            transformationMatrix = mayaTransform.transformation()
    
            actualTranslation = list(
                transformationMatrix.translation(OM.MSpace.kTransform))
            self.assertTrue(
                Gf.IsClose(expectedTranslation, actualTranslation, self.EPSILON))
    
            expectedRotation = [Gf.DegreesToRadians(x) for x in expectedRotation]
            actualRotation = transformationMatrix.rotationOrientation().asEulerRotation()
            actualRotation = list(actualRotation)
            #print rotOrderName, actualRotation
            self.assertTrue(
                Gf.IsClose(expectedRotation, actualRotation, self.EPSILON))
    
            actualScale = list(transformationMatrix.scale(OM.MSpace.kTransform))
            self.assertTrue(
                Gf.IsClose(expectedScale, actualScale, self.EPSILON))


    def testPivot(self):
        """
        Tests that pivotPosition attribute doesn't interfere with the matrix
        that we get in maya when importing a usd file.
        """
        def _usdToMayaPath(usdPath):
            return str(usdPath).replace('/', '|')
        from maya import cmds
        usdFile = os.path.join(self.inputPath, "UsdImportXformsTest", "pivotTests.usda")
        from pxr import Usd, UsdGeom
        stage = Usd.Stage.Open(usdFile)
        xformCache = UsdGeom.XformCache()

        cmds.usdImport(file=os.path.abspath(usdFile), primPath='/World')

        usdPaths = [
                '/World/anim/chars/SomeCharacter/Geom/Face/Eyes/LEye',
                '/World/anim/chars/SomeCharacter/Geom/Face/Eyes/LEye/Sclera_sbdv',
                '/World/anim/chars/SomeCharacter/Geom/Face/Eyes/REye/Sclera_sbdv',
                '/World/anim/chars/SomeCharacter/Geom/Hair/HairStandin/Hair/Hair_sbdv',
                '/World/anim/chars/SomeCharacter/Geom/Hair/HairStandin/Hair/HairFrontPiece_sbdv',
                ]

        for usdPath in usdPaths:
            usdMatrix = xformCache.GetLocalToWorldTransform(stage.GetPrimAtPath(usdPath))
            mayaPath = _usdToMayaPath(usdPath)
            mayaMatrix = Gf.Matrix4d(*cmds.xform(mayaPath, query=True, matrix=True, worldSpace=True))

            print('testing matrix at', usdPath)
            self.assertTrue(Gf.IsClose(
                usdMatrix.ExtractTranslation(), 
                mayaMatrix.ExtractTranslation(), 
                self.EPSILON))

if __name__ == '__main__':
    # Run every TestCase defined in this module, reporting each test by name.
    unittest.main(module='__main__', verbosity=2)
